import argparse
import os

print('Parsing args')

parser = argparse.ArgumentParser()
parser.add_argument("--bs", type=int, default=32)
parser.add_argument("--num_epochs", type=int, default=50)
parser.add_argument("--n_frequency", type=int, default=200)
parser.add_argument("--head_lr", type=float, default=5e-3)
parser.add_argument("--frequency_lr", type=float, default=1e-1)
parser.add_argument("--location_lr", type=float, default=1e-4)
parser.add_argument("--learn_location_iter", type=int, default=500)
parser.add_argument("--weight_decay", type=float, default=0.0)

## For ViT computer vision tasks
parser.add_argument("--model-name-or-path", type=str,
                    required=True,
                    choices=[
                        "google/vit-base-patch16-224-in21k",
                        "google/vit-large-patch16-224-in21k",
                        "google/vit-huge-patch14-224-in21k",
                    ])
parser.add_argument("--dataset-name", type=str,
                    required=True,
                    choices=[
                        "pets",
                        "dtd",
                        "resisc",
                        "eurosat",
                        "cars",
                        "fgvc",
                        "cifar10",
                        "cifar100",
                    ])

parser.add_argument("--mode", type=str, choices=["fourier", "lora", "head", "full", "loca"])

parser.add_argument("--lora-r", type=int, default=16)
parser.add_argument("--lora-alpha", type=int, default=16)
parser.add_argument("--lora-dropout", type=float, default=0)
parser.add_argument("--loca_dropout", type=float, default=0)
parser.add_argument("--scale", type=float, default=1.0)
parser.add_argument("--loca_dct_mode", type=str, default='default', choices=["default", "sparse", "fast"])

parser.add_argument("--n_trial", type=int, default=1)

parser.add_argument("--results-dir", type=str, default="results")
parser.add_argument("--cache-dir", type=str, default=os.path.join(os.getenv("HF_HOME"), ".cache"))
parser.add_argument("--data-local-dir", type=str, default=None)


def get_args():
    return parser.parse_args()